import shap
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-cnn-12-6")
model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-cnn-12-6")
explainer = shap.Explainer(model, tokenizer)
news = [
"An American woman died aboard a cruise ship that docked at Rio de Janeiro on Tuesday, \
the same ship on which 86 passengers previously fell ill, according to the state-run Brazilian news agency, \
Agencia Brasil. The American tourist died aboard the MS Veendam, owned by cruise operator Holland America. \
Federal Police told Agencia Brasil that forensic doctors were investigating her death. \
The ship's doctors told police that the woman was elderly and suffered from diabetes and hypertension, \
according the agency. The other passengers came down with diarrhea prior to her death during an earlier part of the trip, \
the ship's doctors said. The Veendam left New York 36 days ago for a South America tour."
]
inputs = tokenizer(news, max_length=1024, truncation=True, return_tensors="pt")
summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=50)
tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
" The MS Veendam left New York 36 days ago for a South America tour . The ship's doctors told police that the woman was elderly and suffered from diabetes and hypertension . The other passengers came down with diarrhea prior to her death ."
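The `generate` call above uses beam search (`num_beams=2`), which keeps the two most probable partial sequences at each decoding step instead of committing to the single best token. The toy below (pure Python, no model; the vocabulary and log-probabilities are invented) sketches why a wider beam can recover a higher-probability sequence than greedy decoding:

```python
import heapq

# Invented next-token log-probabilities for a tiny vocabulary --
# a stand-in for the summarizer, not the distilbart model.
LOG_PROBS = {
    (): {"The": -0.1, "A": -0.4},
    ("The",): {"end": -2.0},
    ("A",): {"cat": -0.1},
    ("The", "end"): {".": -0.1},
    ("A", "cat"): {"sat": -0.1},
}

def beam_search(num_beams, length=3):
    beams = [(0.0, ())]  # (cumulative log-prob, token tuple)
    for _ in range(length):
        candidates = [
            (score + lp, seq + (tok,))
            for score, seq in beams
            for tok, lp in LOG_PROBS.get(seq, {}).items()
        ]
        if not candidates:
            break
        beams = heapq.nlargest(num_beams, candidates)
    return max(beams)

# Greedy (num_beams=1) commits to "The" and gets trapped in a
# low-probability continuation; a beam of 2 keeps "A" alive and wins.
print(beam_search(num_beams=1))
print(beam_search(num_beams=2))
```

`min_length` and `max_length` in the real call bound the summary length the same way `length` bounds this toy loop.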
The runtime depends on the length of the article and on the model itself; here, computing the SHAP values takes about 15 minutes on a 1.1 GHz CPU.
shap_values = explainer(news)
shap.plots.text(shap_values)
Partition explainer: 2it [13:30, 810.90s/it]
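The Partition explainer keeps the number of model evaluations manageable by recursively splitting the input into contiguous token spans and masking them, rather than enumerating all 2^n coalitions of exact Shapley values. The self-contained sketch below illustrates that idea with an invented word-counting function standing in for the summarizer; it is a rough caricature, not SHAP's actual algorithm:

```python
SALIENT = {"died", "diabetes", "hypertension"}

def f(tokens):
    # Stand-in "model": scores a token set by counting salient words present.
    return sum(t in SALIENT for t in tokens)

def attribute(tokens, indices=None, context=frozenset()):
    """Credit each token in `indices` for the change in f() when it is
    unmasked, splitting the span in half recursively (Partition-style)."""
    if indices is None:
        indices = list(range(len(tokens)))
    if len(indices) == 1:
        # Marginal gain of unmasking this one token on top of `context`.
        with_tok = f([tokens[i] for i in context | set(indices)])
        without = f([tokens[i] for i in context])
        return {indices[0]: with_tok - without}
    mid = len(indices) // 2
    left, right = indices[:mid], indices[mid:]
    out = {i: 0.0 for i in indices}
    # Average over both orders of unmasking the two halves -- a coarse
    # stand-in for averaging over all orderings in exact Shapley values.
    for first, second in ((left, right), (right, left)):
        for i, v in attribute(tokens, first, context).items():
            out[i] += v / 2
        for i, v in attribute(tokens, second, context | frozenset(first)).items():
            out[i] += v / 2
    return out

vals = attribute("the woman died of diabetes".split())
print(vals)  # only "died" (index 2) and "diabetes" (index 4) get credit
```

Note the efficiency property: the attributions sum to the score of the fully unmasked input, just as SHAP values sum to the model output minus the baseline.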
news2 = [
"The tower is 324 metres (1,063 ft) tall, \
about the same height as an 81-storey building, \
and the tallest structure in Paris. \
Its base is square, \
measuring 125 metres (410 ft) on each side. \
During its construction, \
the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, \
a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. \
It was the first structure to reach a height of 300 metres. \
Due to the addition of a broadcasting aerial at the top of the tower in 1957, \
it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, \
the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct."
]
inputs = tokenizer(news2, max_length=1024, truncation=True, return_tensors="pt")
summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=50)
tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
' The Eiffel Tower is 324 metres (1,063 ft) tall and is the tallest structure in Paris . It is the second tallest free-standing structure in France after the Millau Viaduct . It was the first'
shap_values2 = explainer(news2)
shap.plots.text(shap_values2)
Partition explainer: 2it [18:08, 1088.92s/it]